def get_dataset_config_class(dataset_name):
    """Return the configuration class whose name is *dataset_name*.

    The name is resolved in this module's global namespace, so it must match
    the name of a class defined (or imported) here exactly, e.g. ``"Jango"``.
    The class object itself is returned; the caller instantiates it.

    Args:
        dataset_name: Name of the configuration class to look up.

    Returns:
        The matching class object.

    Raises:
        NotImplementedError: If no class with that name exists in this module.
    """
    config_class = globals().get(dataset_name)
    # Accept classes only. Without this check, any other module-level name
    # (this function itself, ``__name__``, imported modules, ...) would be
    # returned as if it were a dataset configuration.
    if not isinstance(config_class, type):
        raise NotImplementedError("Dataset not found: {}".format(dataset_name))
    return config_class

class Jango:
    """Hyperparameter configuration for the ``Jango`` dataset.

    Every setting is stored as a plain instance attribute. The class can be
    looked up by name via ``get_dataset_config_class("Jango")``.
    """

    def __init__(self):
        super().__init__()

        # Feature dimensions and window length.
        self.input_dim = 96
        self.pos_dim = 2
        self.window_size = 5

        # Batch size plus cursor-window / time-delay settings.
        self.batch_size = 256
        self.cursor_window_size = 8
        self.time_delay = 4

        # Loss-term weights. The names suggest KL-divergence, reconstruction,
        # MSE, domain, gradient-reversal (GRL) and HSIC terms -- confirm the
        # exact usage in the training code.
        self.kld_weight_rec = 1e-2  # alternative setting: 1e-1
        self.kld_weight_pos = 1
        self.rec_weight = 1  # alternative setting: 0.5

        self.mse_weight = 1
        self.domain_weight = 1

        self.grl_weight = 1e-4
        self.hsic_weight = 1

class Spike:
    """Hyperparameter configuration for the ``Spike`` dataset.

    Every setting is stored as a plain instance attribute. The class can be
    looked up by name via ``get_dataset_config_class("Spike")``.
    """

    def __init__(self):
        super().__init__()

        # Feature dimensions and window length.
        self.input_dim = 73
        self.pos_dim = 2
        self.window_size = 6

        # Batch size plus cursor-window / time-delay settings.
        self.batch_size = 256
        self.cursor_window_size = 5
        self.time_delay = 2

        # Loss-term weights. The names suggest KL-divergence, reconstruction,
        # MSE, domain, gradient-reversal (GRL) and HSIC terms -- confirm the
        # exact usage in the training code.
        self.kld_weight_rec = 1e-1
        self.kld_weight_pos = 1
        self.rec_weight = 0.5

        self.mse_weight = 1
        self.domain_weight = 1

        self.grl_weight = 1e-4  # alternative setting: 0
        self.hsic_weight = 1  # alternative setting: 0

class Chewie:
    """Hyperparameter configuration for the ``Chewie`` dataset.

    Every setting is stored as a plain instance attribute. The class can be
    looked up by name via ``get_dataset_config_class("Chewie")``.
    """

    def __init__(self):
        super().__init__()

        # Feature dimensions and window length.
        self.input_dim = 96
        self.pos_dim = 2
        self.window_size = 6

        # Batch size plus cursor-window / time-delay settings.
        self.batch_size = 256
        self.cursor_window_size = 8
        self.time_delay = 4

        # Loss-term weights. The names suggest KL-divergence, reconstruction,
        # MSE, domain, gradient-reversal (GRL) and HSIC terms -- confirm the
        # exact usage in the training code.
        self.kld_weight_rec = 1e-3
        self.kld_weight_pos = 1e-3
        self.rec_weight = 0.5

        self.mse_weight = 1
        self.domain_weight = 1

        self.grl_weight = 1e-2  # alternative setting: 0
        self.hsic_weight = 0  # alternative setting: 1


class Mihili:
    """Hyperparameter configuration for the ``Mihili`` dataset.

    Every setting is stored as a plain instance attribute. The class can be
    looked up by name via ``get_dataset_config_class("Mihili")``.
    """

    def __init__(self):
        super().__init__()

        # Feature dimensions and window length.
        self.input_dim = 95
        self.pos_dim = 2
        self.window_size = 6

        # Batch size plus cursor-window / time-delay settings.
        self.batch_size = 256
        self.cursor_window_size = 5
        self.time_delay = 4

        # Loss-term weights. The names suggest KL-divergence, reconstruction,
        # MSE, domain, gradient-reversal (GRL) and HSIC terms -- confirm the
        # exact usage in the training code.
        self.kld_weight_rec = 1e-3
        self.kld_weight_pos = 1e-3
        self.rec_weight = 0.5

        self.mse_weight = 1
        self.domain_weight = 1

        self.grl_weight = 1e-2  # alternative setting: 0
        self.hsic_weight = 0  # alternative setting: 1

class Mihili_RT:
    """Hyperparameter configuration for the ``Mihili_RT`` dataset.

    Every setting is stored as a plain instance attribute. The class can be
    looked up by name via ``get_dataset_config_class("Mihili_RT")``.
    """

    def __init__(self):
        super().__init__()

        # Feature dimensions and window length.
        self.input_dim = 95
        self.pos_dim = 2
        self.window_size = 5

        # Batch size plus cursor-window / time-delay settings.
        self.batch_size = 256
        self.cursor_window_size = 5
        self.time_delay = 2

        # Loss-term weights. The names suggest KL-divergence, reconstruction,
        # MSE, domain, gradient-reversal (GRL) and HSIC terms -- confirm the
        # exact usage in the training code.
        self.kld_weight_rec = 1e-3
        self.kld_weight_pos = 1
        self.rec_weight = 0.5

        self.mse_weight = 1
        self.domain_weight = 1

        self.grl_weight = 1e-2  # alternative setting: 0
        self.hsic_weight = 0  # alternative setting: 1e-3